import numpy as np
import PIL
import PIL.Image
import torch
from torchvision import transforms

__all__ = ['ToBatchTensor']

class ToBatchTensor:
    """Convert a clip of frames into a single ``torch.Tensor``.

    If the first element of the clip is a ``PIL.Image.Image`` or an
    ``np.ndarray``, every element is converted with
    ``torchvision.transforms.ToTensor`` and the per-frame tensors are stacked
    along a new leading dimension. Otherwise the clip itself is passed to
    ``torch.from_numpy``.
    """

    def __call__(self, clip):
        """Return ``clip`` as one tensor.

        Args:
            clip: An indexable sequence of frames (PIL images or numpy
                arrays), or a numpy array. Note that a multi-dimensional
                numpy array (e.g. one array holding all frames) also takes
                the per-frame ``ToTensor`` path, because ``clip[0]`` is then
                itself an ``np.ndarray``.

        Returns:
            torch.Tensor: For the image/array path, the ``ToTensor`` outputs
            stacked along dim 0. Otherwise, the tensor produced by
            ``torch.from_numpy(clip)``, which shares memory with ``clip``.

        Raises:
            IndexError: If ``clip`` is empty, since its first element is
                probed to decide the conversion path.
            TypeError: From ``torch.from_numpy`` when the first element is
                not an image/array and ``clip`` is not a numpy array.
        """
        if isinstance(clip[0], (PIL.Image.Image, np.ndarray)):
            # ToTensor takes no configuration, so build it once and reuse it
            # for every frame instead of re-instantiating it per element.
            to_tensor = transforms.ToTensor()
            return torch.stack([to_tensor(frame) for frame in clip])
        return torch.from_numpy(clip)

    def __repr__(self):
        # Match torchvision's transform repr so Compose(...) prints cleanly.
        return self.__class__.__name__ + '()'